from typing import List, Tuple

import numpy as np

from centralized_verification.shields.ahead_of_time_shield import AheadOfTimeShield
from centralized_verification.shields.shield import ShieldResult, AgentResult, AgentUpdate, T
from centralized_verification.shields.utils import decentralize_actions


class DecentralizedShieldOracle(AheadOfTimeShield):
    """Shield that judges each agent's proposed action on its own.

    On every step the agents are put into a freshly shuffled priority order.
    The action set from ``self.get_action_set(state)`` is then split into one
    action set per agent with ``decentralize_actions``. An agent keeps its
    proposed action when its own action set allows it; otherwise the action is
    replaced via ``replace_action_agent_result`` with action 0.
    """

    def get_initial_shield_state(self, state, initial_joint_obs) -> T:
        """Return the starting shield state.

        The shield keeps no state between steps (``evaluate_joint_action``
        always hands back ``None``), so the initial state is ``None`` too.
        """
        return None

    def evaluate_joint_action(self, state, _, proposed_action, __) -> Tuple[ShieldResult, None]:
        """Accept or replace every agent's proposed action.

        The work is done once for the whole team purely to save time. The
        decision for a given agent never looks at what the other agents
        proposed: ``proposed_action`` is not passed to ``decentralize_actions``.
        The result is therefore the same as each agent shielding itself.

        Returns a pair. The first item is a list with one ``AgentResult`` per
        entry of ``proposed_action``, in the same order. The second item is
        ``None``, the (empty) next shield state.
        """
        num_agents = len(proposed_action)

        # A brand-new random priority order over the agents is drawn on each call.
        # noinspection PyTypeChecker
        priority: List[int] = np.random.permutation(num_agents).tolist()

        # The zeros list has one entry per agent. It presumably gives each agent's
        # default action to decentralize_actions -- confirm against that helper.
        per_agent_action_sets = decentralize_actions(
            self.get_action_set(state), [0] * num_agents, priority
        )

        results = []
        for allowed, action in zip(per_agent_action_sets, proposed_action):
            if not allowed[action]:
                # Rejected by this agent's own action set: delegate to the base
                # class, asking it to replace the proposal with action 0.
                results.append(self.replace_action_agent_result(action, 0))
                continue
            results.append(AgentResult(AgentUpdate(action=action)))

        return results, None
